#!/usr/bin/env python2
#-*- coding: utf-8 -*-

import traceback

from Ump import utils
from Ump.common import log
from Ump.common import exception

from Ump.objs.session_wrapper import enable_log_and_session, _sw
from Ump.objs.db import models
from Ump.objs.manager_base import Manager
from Ump.objs.snapshot.manager import SnapshotManager

from Ump.lich.cgsnapshot import LichCGSnapshot, LichCGSnapshotParam

LOG = log.get_log('Ump.objs.cgsnapshot.manager')


@models.add_model(models.CGSnapshot)
class CGSnapshotManager(Manager):
    """Manage CGSnapshot records: one named snapshot taken across all
    volumes of a VGroup.

    Every operation is sent to the Lich backend through ``LichCGSnapshot``
    and mirrored in the local ``models.CGSnapshot`` table.
    """

    def __init__(self):
        super(CGSnapshotManager, self).__init__()
        self.lichCGSnapshot = LichCGSnapshot()
        # NOTE(review): snapManager is not referenced by the methods below;
        # kept in case other modules access it.
        self.snapManager = SnapshotManager()

    def _get_vols(self, volumes):
        """Return Lich volume paths of the form ``<pool realname>/<volume name>``."""
        return ['%s/%s' % (volume.pool.realname, volume.name) for volume in volumes]

    def _build_param(self, volumes, snap_name):
        """Build a LichCGSnapshotParam covering every volume in *volumes*.

        The HTTP host is selected from the first volume's cluster, and the
        protocol is taken from the first volume.

        Raises:
            exception.VolumeNotFound: if *volumes* is empty (instead of
                letting ``volumes[0]`` raise IndexError).
        """
        if not volumes:
            raise exception.VolumeNotFound()
        first = volumes[0]
        # Keyword arguments are evaluated left to right, so _select_http runs
        # before _get_vols, matching the original statement order.
        return LichCGSnapshotParam(host_ip=self._select_http(first.cluster_id),
                                   vols=self._get_vols(volumes),
                                   snap_name=snap_name,
                                   protocol=first.protocol)

    @enable_log_and_session(resource='cgsnapshot', event='create')
    def create(self, _logger, kwargs):
        """Create a CGSnapshot named ``kwargs['cgsnapshot_name']``.

        ``_logger`` is presumably injected by ``enable_log_and_session``.
        ``kwargs`` is forwarded to :meth:`_create` as keyword arguments.
        """
        _logger.set_obj(kwargs['cgsnapshot_name'])
        return self._create(**kwargs)

    def _create(self, cluster_id=None, username='admin', **kwargs):
        """Snapshot every volume of a vgroup on Lich, then save a DB record.

        Args:
            cluster_id: cluster whose HTTP host receives the Lich request.
            username: owner of the new record. It is also written into
                ``kwargs`` because the vgroup/cgsnapshot lookups take that dict.
            kwargs: must contain ``cgsnapshot_name``, plus whatever
                ``_sw.get_vgroup`` needs to locate the vgroup.

        Returns:
            The saved ``models.CGSnapshot``.

        Raises:
            Exception: if the vgroup has no volumes or the name is already used.
        """
        cluster = _sw.get_cluster(cluster_id=cluster_id)
        kwargs['username'] = username
        name = kwargs.get('cgsnapshot_name')
        user = _sw.get_user(username=username)

        # NOTE(review): if _sw.get_vgroup can return None, the next line
        # raises AttributeError; confirm whether a VGroupNotFound guard is needed.
        vgroup = _sw.get_vgroup(kwargs)
        volumes = vgroup.volumes
        if not volumes:
            raise Exception("卷组不能为空")
        if _sw.get_cgsnapshot(kwargs):
            raise Exception("指定的名称已被使用")

        values = {
            'user_id': user.id,
            'name': name,
            'vgroup_id': vgroup.id,
        }

        # The backend call comes first. If it raises, no DB row is written.
        param = LichCGSnapshotParam(host_ip=self._select_http(cluster.id),
                                    vols=self._get_vols(volumes),
                                    snap_name=name,
                                    protocol=volumes[0].protocol)
        self.lichCGSnapshot.create(param)

        return models.CGSnapshot(values).save()

    @enable_log_and_session(resource='cgsnapshot', event='delete')
    def delete(self, _logger, kwargs):
        """Delete a CGSnapshot on Lich, then remove its DB record.

        The snapshot is looked up by ``kwargs['id']`` when that is truthy.
        Otherwise it is looked up with ``_sw.get_cgsnapshot(kwargs)``.

        Returns:
            The deleted ``models.CGSnapshot`` object.

        Raises:
            exception.CGSnapshotNotFound: if no snapshot matches.
            exception.VolumeNotFound: if the snapshot's vgroup has no volumes.
        """
        cgsnapshot_id = kwargs.get('id')
        if cgsnapshot_id:
            cgsnapshot = self._get_one(cgsnapshot_id)
        else:
            cgsnapshot = _sw.get_cgsnapshot(kwargs)
        if not cgsnapshot:
            raise exception.CGSnapshotNotFound(kwargs)

        _logger.set_obj(cgsnapshot.name)

        param = self._build_param(cgsnapshot.vgroup.volumes, cgsnapshot.name)
        self.lichCGSnapshot.delete(param)

        cgsnapshot.delete()
        return cgsnapshot

    @enable_log_and_session(resource='cgsnapshot', event='rollback')
    def rollback(self, _logger, kwargs):
        """Ask Lich to roll the snapshot's vgroup volumes back to it.

        Args:
            kwargs: must contain ``id``, the CGSnapshot id.

        Returns:
            The ``models.CGSnapshot`` that was rolled back to.

        Raises:
            exception.CGSnapshotNotFound: if ``kwargs['id']`` does not exist.
            exception.VolumeNotFound: if the snapshot's vgroup has no volumes.
        """
        cgsnapshot = self._get_one(kwargs['id'])
        # Record the target object, as create() and delete() do.
        _logger.set_obj(cgsnapshot.name)

        param = self._build_param(cgsnapshot.vgroup.volumes, cgsnapshot.name)
        self.lichCGSnapshot.rollback(param)
        return cgsnapshot

    def _get_one(self, cgsnapshot_id):
        """Return the CGSnapshot with *cgsnapshot_id* or raise CGSnapshotNotFound."""
        cgsnapshot = _sw.get_one(models.CGSnapshot, cgsnapshot_id)
        if not cgsnapshot:
            raise exception.CGSnapshotNotFound(cgsnapshot_id=cgsnapshot_id)
        return cgsnapshot

    def clone(self, kwargs):
        """Not implemented; returns None."""
        pass

    def sync_cgsnapshot(self, cluster_id=1):
        """Remove DB CGSnapshots that Lich no longer lists, for every vgroup.

        ``cluster_id`` is accepted for interface compatibility but is unused.
        """
        for vgroup in _sw.db_vgroups():
            # _sync_cgsnapshot_with_vgroup takes only the vgroup. Passing _sw
            # as well, as the old code did, raised TypeError.
            self._sync_cgsnapshot_with_vgroup(vgroup)

    def _sync_cgsnapshot_with_vgroup(self, vgroup):
        """Delete the DB CGSnapshots of *vgroup* whose names Lich does not list.

        Does nothing when the vgroup has no volumes.
        """
        LOG.info('sync cgsnapshot for vgroup:')
        LOG.info(vgroup)
        volumes = vgroup.volumes
        if not volumes:
            return

        # NOTE(review): the other Lich calls in this class pass
        # host_ip=self._select_http(...); confirm LichCGSnapshotParam also
        # accepts cluster_id here.
        param = LichCGSnapshotParam(cluster_id=volumes[0].cluster_id,
                                    vols=self._get_vols(volumes),
                                    protocol=volumes[0].protocol)
        lich_cgsnapshots = self.lichCGSnapshot.list(param)

        # Iterate over a copy, since deleting rows may mutate the collection.
        for dbcgsnapshot in list(vgroup.cgsnapshots):
            if dbcgsnapshot.name not in lich_cgsnapshots:
                dbcgsnapshot.delete()


if __name__ == '__main__':
    # Script entry point: run one CGSnapshot sync pass with a fresh manager.
    CGSnapshotManager().sync_cgsnapshot()
